/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package testing

import (
	"context"
	"sync/atomic"

	mock "github.com/stretchr/testify/mock"

	state "github.com/dapr/components-contrib/state"
)

// MockStateStore is an autogenerated mock type for the Store type.
//
// It embeds testify's mock.Mock: methods that forward to Called (BulkDelete,
// BulkSet, Delete, Get, Init, Set) must have matching expectations registered
// with On before use. BulkGet, Ping, Features and Close are plain stubs that
// return zero values and are not recorded.
type MockStateStore struct {
	mock.Mock
}

// BulkDelete provides a mock function with given fields: ctx, req.
//
// opts is not forwarded to Called, so expectations should match on (ctx, req)
// only. If the configured return value is a
// func(context.Context, []state.DeleteRequest) error, it is invoked to compute
// the result; otherwise the configured error (or nil) is returned.
func (_m *MockStateStore) BulkDelete(ctx context.Context, req []state.DeleteRequest, opts state.BulkStoreOpts) error {
	args := _m.Called(ctx, req)
	if fn, isFunc := args.Get(0).(func(context.Context, []state.DeleteRequest) error); isFunc {
		return fn(ctx, req)
	}
	return args.Error(0)
}

// BulkSet provides a mock function with given fields: ctx, req.
//
// opts is not forwarded to Called, so expectations should match on (ctx, req)
// only. A configured func(context.Context, []state.SetRequest) error return
// value is invoked to compute the result; any other value is returned as the
// error.
func (_m *MockStateStore) BulkSet(ctx context.Context, req []state.SetRequest, opts state.BulkStoreOpts) error {
	args := _m.Called(ctx, req)
	if fn, isFunc := args.Get(0).(func(context.Context, []state.SetRequest) error); isFunc {
		return fn(ctx, req)
	}
	return args.Error(0)
}

// Delete provides a mock function with given fields: ctx, req.
//
// A configured func(context.Context, *state.DeleteRequest) error return value
// is invoked to compute the result; otherwise the configured error is returned.
func (_m *MockStateStore) Delete(ctx context.Context, req *state.DeleteRequest) error {
	args := _m.Called(ctx, req)
	if fn, isFunc := args.Get(0).(func(context.Context, *state.DeleteRequest) error); isFunc {
		return fn(ctx, req)
	}
	return args.Error(0)
}

// Get provides a mock function with given fields: ctx, req.
//
// Return value 0 may be a func(context.Context, *state.GetRequest)
// *state.GetResponse (invoked to build the response), nil, or a
// *state.GetResponse. Return value 1 may be a
// func(context.Context, *state.GetRequest) error (invoked to build the error)
// or an error/nil. The response is resolved before the error.
func (_m *MockStateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
	args := _m.Called(ctx, req)

	var resp *state.GetResponse
	switch v := args.Get(0).(type) {
	case func(context.Context, *state.GetRequest) *state.GetResponse:
		resp = v(ctx, req)
	case nil:
		// No response configured: leave resp nil.
	default:
		resp = v.(*state.GetResponse)
	}

	if fn, isFunc := args.Get(1).(func(context.Context, *state.GetRequest) error); isFunc {
		return resp, fn(ctx, req)
	}
	return resp, args.Error(1)
}

// BulkGet is a stub: it is not recorded by the mock and always returns a nil
// slice and a nil error, regardless of the requests.
func (*MockStateStore) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) {
	var none []state.BulkGetResponse
	return none, nil
}

// Init provides a mock function with given fields: metadata.
//
// ctx is not forwarded to Called, so expectations match on metadata alone.
// A configured func(state.Metadata) error return value is invoked to compute
// the result; otherwise the configured error is returned.
func (_m *MockStateStore) Init(ctx context.Context, metadata state.Metadata) error {
	args := _m.Called(metadata)
	if fn, isFunc := args.Get(0).(func(state.Metadata) error); isFunc {
		return fn(metadata)
	}
	return args.Error(0)
}

// Ping is a stub that always succeeds; it is not recorded by the mock.
func (*MockStateStore) Ping() error {
	return nil
}

// Set provides a mock function with given fields: ctx, req.
//
// A configured func(context.Context, *state.SetRequest) error return value is
// invoked to compute the result; otherwise the configured error is returned.
func (_m *MockStateStore) Set(ctx context.Context, req *state.SetRequest) error {
	args := _m.Called(ctx, req)
	if fn, isFunc := args.Get(0).(func(context.Context, *state.SetRequest) error); isFunc {
		return fn(ctx, req)
	}
	return args.Error(0)
}

// Features reports no optional features (a nil slice); it is not recorded by
// the mock.
func (*MockStateStore) Features() []state.Feature {
	var none []state.Feature
	return none
}

// Close is a no-op that always succeeds; it is not recorded by the mock.
func (*MockStateStore) Close() error {
	return nil
}

// FailingStatestore is a state store test double whose operations consult
// Failure (via Failure.PerformFailure) for the key(s) involved and return any
// error it produces.
//
// BulkFailKey, when it holds a non-empty key, makes BulkGet consult Failure
// for that key before processing the individual requests; an error there
// fails the whole BulkGet call. It is an atomic.Pointer so tests can change it
// while the store is in use.
type FailingStatestore struct {
	Failure     Failure
	BulkFailKey atomic.Pointer[string]
}

// BulkDelete consults Failure for each request's key, in order, and returns the
// first error produced. Nothing is deleted; nil means every key passed.
func (f *FailingStatestore) BulkDelete(ctx context.Context, req []state.DeleteRequest, opts state.BulkStoreOpts) error {
	for i := range req {
		if err := f.Failure.PerformFailure(req[i].Key); err != nil {
			return err
		}
	}
	return nil
}

// BulkSet consults Failure for each request's key, in order, and returns the
// first error produced. Nothing is stored; nil means every key passed.
func (f *FailingStatestore) BulkSet(ctx context.Context, req []state.SetRequest, opts state.BulkStoreOpts) error {
	for i := range req {
		if err := f.Failure.PerformFailure(req[i].Key); err != nil {
			return err
		}
	}
	return nil
}

// Delete consults Failure for req.Key and returns whatever it reports.
func (f *FailingStatestore) Delete(ctx context.Context, req *state.DeleteRequest) error {
	err := f.Failure.PerformFailure(req.Key)
	return err
}

// Get consults Failure for req.Key and returns its error, if any.
//
// On success it returns an empty *state.GetResponse, except for the sentinel
// key "nilGetKey", which yields a nil response with a nil error.
func (f *FailingStatestore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
	if err := f.Failure.PerformFailure(req.Key); err != nil {
		return nil, err
	}
	if req.Key == "nilGetKey" {
		// Simulates a store that returns neither data nor an error.
		return nil, nil
	}
	return &state.GetResponse{}, nil
}

// BulkGet retrieves each requested key by delegating to Get, one at a time.
//
// If BulkFailKey holds a non-empty key, Failure is consulted for that key first
// and any error it produces is returned for the whole call. Otherwise, a
// per-key error from Get is reported in that item's Error field instead of
// failing the batch. When Get returns neither a response nor an error (as it
// does for "nilGetKey"), the item carries only its Key. Previously this case
// dereferenced a nil response and panicked.
//
// The returned slice is never nil, even for an empty request list.
func (f *FailingStatestore) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) {
	if bfk := f.BulkFailKey.Load(); bfk != nil && *bfk != "" {
		if err := f.Failure.PerformFailure(*bfk); err != nil {
			return nil, err
		}
	}

	// Return keys one by one.
	res := make([]state.BulkGetResponse, 0, len(req))
	for i := range req {
		r, err := f.Get(ctx, &req[i])
		switch {
		case err != nil:
			res = append(res, state.BulkGetResponse{
				Key:   req[i].Key,
				Error: err.Error(),
			})
		case r == nil:
			// No data and no error: report the key with empty data rather than
			// dereferencing the nil response.
			res = append(res, state.BulkGetResponse{
				Key: req[i].Key,
			})
		default:
			res = append(res, state.BulkGetResponse{
				Key:         req[i].Key,
				Data:        r.Data,
				ETag:        r.ETag,
				ContentType: r.ContentType,
				Metadata:    r.Metadata,
			})
		}
	}
	return res, nil
}

// Init ignores its arguments and always succeeds; Failure is not consulted.
func (*FailingStatestore) Init(ctx context.Context, metadata state.Metadata) error {
	return nil
}

// Ping always succeeds; Failure is not consulted.
func (*FailingStatestore) Ping() error {
	return nil
}

// Set consults Failure for req.Key and returns whatever it reports; nothing is
// stored.
func (f *FailingStatestore) Set(ctx context.Context, req *state.SetRequest) error {
	err := f.Failure.PerformFailure(req.Key)
	return err
}

// Close is a no-op that always succeeds.
func (*FailingStatestore) Close() error {
	return nil
}
